# -*- coding: utf-8 -*-
"""
Code Introduction:
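    This module provides helpers for plotting impulse response functions
    (IRFs) with optional error bands, together with axis formatting, a colour
    palette and shaded vertical bars.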
Version History:
    Created: Fri Dec 18 21:54:30 2020
    Current: 

@author: Xing Guo (guoxing.econ@gmail.com)

"""
#%% Preliminaries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.backends.backend_pdf as figpdf
import datetime
import matplotlib.dates as matdates
from matplotlib.ticker import MaxNLocator
import matplotlib.ticker as TickerFun


#%% Basic Classes

class Line:
    """Style specification for a single plotted line.

    Attributes:
        Color: matplotlib colour, passed as ``color``.
        Style: matplotlib line style, passed as ``linestyle``.
        Width: line width, passed as ``linewidth``.
        Marker: marker style, passed as ``marker`` ('' draws no markers).
        MarkerSize: marker size, passed as ``markersize``.
    """

    def __init__(self,Color='black',Style='solid',Width=3,Marker='',MarkerSize=6):
        self.Color = Color
        self.Style = Style
        self.Width = Width
        self.Marker = Marker
        self.MarkerSize = MarkerSize

    def Plot(self,x,y,ax=None,Label=None):
        """Draw ``y`` against ``x`` on ``ax`` using this line style.

        Args:
            x, y: data passed straight to ``ax.plot``.
            ax: target axes; defaults to the current axes (``plt.gca()``).
            Label: legend label; left out of the call entirely when None.

        Returns:
            The axes the line was drawn on.
        """
        if ax is None:
            ax = plt.gca()
        # Marker settings were previously stored but never forwarded.
        style = dict(linestyle=self.Style, linewidth=self.Width, color=self.Color,
                     marker=self.Marker, markersize=self.MarkerSize)
        if Label is not None:
            style['label'] = Label
        ax.plot(x, y, **style)
        return ax

class Area:
    """Fill style for a shaded band between a lower and an upper curve.

    Attributes:
        Color: fill colour, passed as ``facecolor``.
        Alpha: fill transparency, passed as ``alpha``.
    """

    def __init__(self,Color='gray',Alpha=0.2):
        self.Color = Color
        self.Alpha = Alpha

    def Plot(self,  x,y_low,y_upp,ax=None):
        """Shade the region between ``y_low`` and ``y_upp`` over ``x``.

        Args:
            x: x positions.
            y_low, y_upp: lower and upper bounds of the band at each ``x``.
            ax: target axes; defaults to the current axes (``plt.gca()``).

        Returns:
            The axes the band was drawn on.
        """
        if ax is None:
            ax = plt.gca()
        ax.fill_between(x,y_upp,y_low,facecolor=self.Color,alpha=self.Alpha)
        return ax
     
        
class IRF:
    """One impulse-response series drawn from a DataFrame, with an optional band.

    Args:
        IrfVar: column of ``DS`` holding the response values.
        TimeVar: column of ``DS`` used for the x-axis; '' uses ``DS.index``.
        Line: line style object exposing ``Plot(x, y, ax=...)``; a fresh
            ``Line()`` is created when None.
        Area_Flag: how the shaded band is built:
            'Std' -> Area_Param = [std_column, width]; band is IrfVar +/- width*std.
            'CI'  -> Area_Param = [lower_column, upper_column].
            any other value (including None) -> no band is drawn.
        Area_Param: band parameters as described for ``Area_Flag``.
        Area_Info: fill style object exposing ``Plot(x, low, upp, ax=...)``;
            a fresh ``Area()`` is created when None.
    """

    # The ``Line`` parameter of __init__ shadows the module-level class of the
    # same name, so the default style classes are captured here instead.
    _LineClass = Line
    _AreaClass = Area

    def __init__(self,IrfVar,TimeVar='',Line=None,Area_Flag=None,Area_Param=None,Area_Info=None):
        self.IrfVar = IrfVar
        self.TimeVar = TimeVar
        # Build defaults per instance: a shared default object would let a
        # style change on one IRF silently affect every other default IRF.
        self.Line = self._LineClass() if Line is None else Line
        self.Area_Flag = Area_Flag
        self.Area_Param = [] if Area_Param is None else Area_Param
        self.Area_Info = self._AreaClass() if Area_Info is None else Area_Info

    def Plot(self,DS,ax=None):
        """Draw the response (and its band, if configured) from ``DS`` onto ``ax``.

        Args:
            DS: DataFrame containing ``IrfVar`` and any band/time columns.
            ax: target axes; defaults to the current axes (``plt.gca()``).

        Returns:
            (ax, Scale), where Scale = {'x': (xmin, xmax), 'y': (ymin, ymax)}
            spans the x values and every plotted y series (NaNs dropped).
        """
        if ax is None:
            ax = plt.gca()
        ## X-Axis
        XAxis = DS.index.to_series() if self.TimeVar=='' else DS[self.TimeVar]
        xlim = (XAxis.min(),XAxis.max())
        ## Y-Axis
        Response = DS[self.IrfVar]
        if self.Area_Flag=='Std':
            StdCol, Width = self.Area_Param[0], self.Area_Param[1]
            Band = (Response-Width*DS[StdCol], Response+Width*DS[StdCol])
        elif self.Area_Flag=='CI':
            Band = (DS[self.Area_Param[0]], DS[self.Area_Param[1]])
        else:
            Band = None

        # Forward ``ax`` so the series lands on the requested axes, not plt.gca().
        ax = self.Line.Plot(XAxis,Response,ax=ax)
        Plotted = [Response]
        if Band is not None:
            ax = self.Area_Info.Plot(XAxis,Band[0],Band[1],ax=ax)
            Plotted.extend(Band)

        ylim = (min(s.dropna().values.min() for s in Plotted),
                max(s.dropna().values.max() for s in Plotted))
        Scale = {'x':xlim,'y':ylim}
        return ax, Scale

#%% Basic Functions
def Setup_Fig(FigSize=(1/2,1/3),FigDpi=200):
    """Create a figure sized relative to an A4 page and switch on LaTeX serif text.

    Args:
        FigSize: (width_fraction, height_fraction). NOTE(review): both
            fractions are multiplied by the page *width* (8.25 in); the page
            height (11.75 in) is never used — confirm this is intended.
        FigDpi: figure resolution in dots per inch.

    Returns:
        The newly created matplotlib figure.

    Side effects:
        Sets the global rc options text.usetex=True and font.family='serif'.
    """
    page_width_in = 8 + 1/4
    width_in = FigSize[0] * page_width_in
    height_in = FigSize[1] * page_width_in
    figure = plt.figure(figsize=(width_in, height_in), dpi=FigDpi)
    # Global rc changes: affect all text drawn after this call.
    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')
    return figure

def Setup_Ax_IRF(ax=None,XLimit=(np.nan,np.nan),YLimit=(np.nan,np.nan),Symmetry=False,XTickStep=4):
    """Format an axes for an impulse-response plot.

    Limits only ever widen: each bound in ``XLimit``/``YLimit`` replaces the
    current bound only when it is finite and lies outside the current view.
    NaN keeps the current bound.

    Args:
        ax: axes to format; defaults to the current axes (``plt.gca()``).
        XLimit, YLimit: (lower, upper) candidate limits.
        Symmetry: if True, make the y-range symmetric around zero.
        XTickStep: preferred x tick multiple, passed as ``steps=[XTickStep]``
            to MaxNLocator (which only accepts steps between 1 and 10).

    Returns:
        The formatted axes.
    """
    if ax is None:
        ax = plt.gca()

    def _widen(current, requested):
        # Keep the current bound unless the requested one is finite and wider.
        low, upp = current
        if np.isfinite(requested[0]) and requested[0] < low:
            low = requested[0]
        if np.isfinite(requested[1]) and requested[1] > upp:
            upp = requested[1]
        return low, upp

    ## Scale
    XMin, XMax = _widen(ax.get_xlim(), XLimit)
    ax.set_xlim(XMin,XMax)
    YMin, YMax = _widen(ax.get_ylim(), YLimit)
    if Symmetry:
        YAbsMax = max(abs(YMin),abs(YMax))
        YMin, YMax = -YAbsMax, YAbsMax
    ax.set_ylim(YMin,YMax)

    ## Tick locations
    ax.xaxis.set_major_locator(TickerFun.MaxNLocator(steps=[XTickStep]))
    ax.yaxis.set_major_locator(TickerFun.MaxNLocator(nbins=7,steps=[2,5,10]))

    ## Hide the top and right spines and their ticks
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

    ## Trim the bottom/left spines to the outermost ticks inside the view
    ax.spines['bottom'].set_bounds(max(np.min(ax.get_xticks()),ax.get_xlim()[0]),
                                   min(np.max(ax.get_xticks()),ax.get_xlim()[1]))
    ax.spines['left'].set_bounds(max(np.min(ax.get_yticks()),ax.get_ylim()[0]),
                                 min(np.max(ax.get_yticks()),ax.get_ylim()[1]))
    ## Zero line, only when zero lies strictly inside the y-range.
    # Drawn on ``ax`` (not plt's current axes) so it lands on the formatted axes.
    if YMax>0 and YMin<0:
        ax.axhline(y=0, linewidth=1,color='k',alpha=0.5)
    ## Gridlines
    ax.grid(True,axis='both',linestyle='dashed',alpha=0.5)
    return ax

def Setup_Ax(ax=None,XLimit=(np.nan,np.nan),YLimit=(np.nan,np.nan),Symmetry=False,
             XTickStep=(1,2,5),XTickNbins=5,YTickStep=(1,2,5),YTickNbins=5,XDateFormatter='%y'):
    """Set up a general-purpose axes: limits, tick locators, spines and grid.

    Limits only ever widen: each bound in ``XLimit``/``YLimit`` replaces the
    current bound only when it is finite and lies outside the current view.
    NaN keeps the current bound.

    Args:
        ax: axes to format; defaults to the current axes (``plt.gca()``).
        XLimit, YLimit: (lower, upper) candidate limits.
        Symmetry: if True, make the y-range (and y ticks) symmetric around zero.
        XTickStep, YTickStep: acceptable tick multiples for MaxNLocator ``steps``.
        XTickNbins, YTickNbins: maximum number of tick intervals (``nbins``).
        XDateFormatter: currently unused. NOTE(review): it looks intended for
            a date formatter on the x-axis; kept only for call compatibility.

    Returns:
        The formatted axes.
    """
    if ax is None:
        ax = plt.gca()

    def _widen(current, requested):
        # Keep the current bound unless the requested one is finite and wider.
        low, upp = current
        if np.isfinite(requested[0]) and requested[0] < low:
            low = requested[0]
        if np.isfinite(requested[1]) and requested[1] > upp:
            upp = requested[1]
        return low, upp

    ## Scale
    XMin, XMax = _widen(ax.get_xlim(), XLimit)
    ax.set_xlim(XMin,XMax)
    YMin, YMax = _widen(ax.get_ylim(), YLimit)
    if Symmetry:
        YAbsMax = max(abs(YMin),abs(YMax))
        YMin, YMax = -YAbsMax, YAbsMax
    ax.set_ylim(YMin,YMax)

    ## Tick locations
    ax.xaxis.set_major_locator(MaxNLocator(nbins=XTickNbins, steps=XTickStep))
    ax.yaxis.set_major_locator(MaxNLocator(nbins=YTickNbins, steps=YTickStep, symmetric=Symmetry))

    ## Hide the top and right spines and their ticks
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

    ## Trim the bottom/left spines to the outermost ticks inside the view
    ax.spines['bottom'].set_bounds(max(np.min(ax.get_xticks()),ax.get_xlim()[0]),
                                   min(np.max(ax.get_xticks()),ax.get_xlim()[1]))
    ax.spines['left'].set_bounds(max(np.min(ax.get_yticks()),ax.get_ylim()[0]),
                                 min(np.max(ax.get_yticks()),ax.get_ylim()[1]))
    ## Zero line, only when zero lies strictly inside the y-range.
    # Drawn on ``ax`` (not plt's current axes) so it lands on the formatted axes.
    if YMax>0 and YMin<0:
        ax.axhline(y=0, linewidth=1,color=MyColor('Black',0.5))
    ## Gridlines
    ax.grid(True,axis='both',linestyle='dashed',color=MyColor('Gray',0.5))
    return ax
def MyColor(ColorStr='Blue',ColorDepth=1):
    """Return a named palette colour as an (r, g, b) tuple of floats.

    Args:
        ColorStr: one of 'Blue', 'Red', 'Green', 'Orange', 'Yellow',
            'Purple', 'Black', 'Gray'.
        ColorDepth: linear blend weight towards the base colour. 1 gives the
            colour itself, 0 gives white, and values in between tint it.

    Returns:
        Tuple of three floats, each ``(1 - ColorDepth) + (c / 255) * ColorDepth``.

    Raises:
        ValueError: if ``ColorStr`` is not a palette name. It subclasses
            Exception, so existing ``except Exception`` handlers still match.
    """
    palette = {
        'Blue':   (0,90,155),
        'Red':    (134,27,57),
        'Green':  (0,158,115),
        'Orange': (230,159,0),
        'Yellow': (240,228,66),
        'Purple': (204,121,167),
        'Black':  (0,0,0),
        'Gray':   (180,180,180),
    }
    try:
        ColorRGB = palette[ColorStr]
    except (KeyError, TypeError):
        raise ValueError("The color is not specified...") from None

    return tuple(1*(1-ColorDepth)+(cc/255)*ColorDepth for cc in ColorRGB)
    
#%% IRF Plot
def Plot_IRF(DS,IRFs,ax=None,XLimit=(np.nan,np.nan),YLimit=(np.nan,np.nan),Symmetry=False,XTickStep=4):
    """Plot one or more IRF series from ``DS`` and format the axes.

    Args:
        DS: DataFrame passed to each ``IRF.Plot``.
        IRFs: a single IRF object, or a list/tuple of them.
        ax: target axes; defaults to the current axes (``plt.gca()``).
        XLimit, YLimit, Symmetry, XTickStep: forwarded to ``Setup_Ax_IRF``
            after the view has been widened to cover every plotted series.

    Returns:
        The formatted axes.

    Raises:
        ValueError: if ``IRFs`` is an empty list/tuple.
    """
    if ax is None:
        ax = plt.gca()
    IRFList = list(IRFs) if isinstance(IRFs, (list, tuple)) else [IRFs]
    if not IRFList:
        raise ValueError("Plot_IRF needs at least one IRF to plot.")

    Scales = []
    for irf in IRFList:
        # Forward ``ax`` so every series is drawn on the requested axes.
        ax, TempScale = irf.Plot(DS,ax=ax)
        Scales.append(TempScale)
    Scale = {'x': (min(s['x'][0] for s in Scales), max(s['x'][1] for s in Scales)),
             'y': (min(s['y'][0] for s in Scales), max(s['y'][1] for s in Scales))}

    # First pass widens the view to cover all data; second applies caller options.
    ax = Setup_Ax_IRF(ax=ax,XLimit=Scale['x'],YLimit=Scale['y'])
    ax = Setup_Ax_IRF(ax=ax,XLimit=XLimit,YLimit=YLimit,Symmetry=Symmetry,XTickStep=XTickStep)

    return ax
    
#%% Color Bar Plot
def Plot_VBar(x,Area=Area(),ax=None):
    """Shade a vertical band over ``x`` spanning the current y-limits of ``ax``.

    Args:
        x: x positions of the band, typically (start, end).
        Area: fill style object exposing ``Plot(x, low, upp, ax=...)``. The
            default instance is created once at import time and shared across
            calls; it is not modified here.
        ax: target axes; defaults to the current axes (``plt.gca()``).

    Returns:
        The axes the band was drawn on.
    """
    if ax is None:
        ax = plt.gca()
    y_min, y_max = ax.get_ylim()
    # One bound per x position (two for the usual (start, end) pair).
    y_low = [y_min] * len(x)
    y_upp = [y_max] * len(x)
    ax = Area.Plot(x,y_low,y_upp,ax=ax)

    return ax
#%% Test
# exec(open('Step_2_Regression_National.py').read())
# DS_MS = pickle.load(open('TempData\DS_MonetaryPolicy.p','rb'))
# DS_DepVar = pickle.load(open('TempData\DS_DepVar_National.p','rb'))
# DS = DS_MS.join(DS_DepVar).sort_index().reset_index()
# LP_0 = LP_Info('InterestRate',ControlVarList=[])
# Reg_0 = LocalProjection(DS,LP_0)

# H_fig = plt.figure()
# ax = H_fig.add_subplot(1,1,1)
# IRF_1 = IRF('Coef',Area_Flag='CI',Area_Param=['CI_Left','CI_Right'])
# # ax,Scale = IRF_1.Plot(ax,Reg_0)

# Plot_IRF(Reg_0,IRF_1)